{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "STARTED RUNNING THIS NOTEBOOK\n"
     ]
    }
   ],
   "source": [
    "import preprocessor as p\n",
    "import numpy as np\n",
    "# Start marker: shows in the cell output that this notebook began running.\n",
    "print(\"STARTED RUNNING THIS NOTEBOOK\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.autograd import Variable\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "from tqdm import tqdm_notebook as tqdm\n",
    "import preprocessor as p\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import random\n",
    "from collections import Counter\n",
    "import spacy\n",
    "from tqdm import tqdm, tqdm_notebook, tnrange\n",
    "import pandas as pd\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, classification_report, confusion_matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle as pkl\n",
    "from collections import defaultdict\n",
    "import pandas as pd\n",
    "import os\n",
    "import numpy as np\n",
    "import json\n",
    "from tqdm import tqdm, tqdm_notebook\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import classification_report, f1_score, accuracy_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "import spacy\n",
    "from tqdm import tqdm, tqdm_notebook, tnrange\n",
    "import pandas as pd\n",
    "from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, classification_report, confusion_matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import random"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Node:\n",
    "    \"\"\"A node of a ``Tree``: holds a user id, tweet id, timestamp and label.\n",
    "\n",
    "    Children live in ``children`` keyed by uid, so adding a second child with\n",
    "    an already-present uid replaces the earlier one.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, uid, tid, time_stamp, label):\n",
    "        self.uid = uid\n",
    "        self.tid = tid\n",
    "        self.time_stamp = time_stamp\n",
    "        self.label = label\n",
    "        # uid -> Node map, plus an ordered list view of the same children.\n",
    "        self.children = {}\n",
    "        self.childrenList = []\n",
    "        self.num_children = 0\n",
    "\n",
    "    def add_child(self, node):\n",
    "        \"\"\"Attach ``node`` as a child, replacing any child with the same uid.\"\"\"\n",
    "        is_new_uid = node.uid not in self.children\n",
    "        self.children[node.uid] = node\n",
    "        if is_new_uid:\n",
    "            self.num_children += 1\n",
    "        # Rebuild the list view so it mirrors the dict's insertion order.\n",
    "        self.childrenList = list(self.children.values())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Tree:\n",
    "    \"\"\"A tree of ``Node`` objects rooted at ``root``.\"\"\"\n",
    "\n",
    "    def __init__(self, root):\n",
    "        self.root = root\n",
    "        self.tweet_id = root.tid\n",
    "        self.uid = root.uid\n",
    "        # Initialised here; no method of this class updates them.\n",
    "        self.height = 0\n",
    "        self.nodes = 0\n",
    "\n",
    "    def show(self):\n",
    "        \"\"\"Print uids breadth-first: the root, then each node's children as one group.\n",
    "\n",
    "        A ``None`` marker is queued after every node's children; popping a\n",
    "        marker ends the line and adds a blank line, so each sibling group\n",
    "        (empty for leaves) gets its own line.\n",
    "        \"\"\"\n",
    "        from collections import deque\n",
    "\n",
    "        # deque gives O(1) pops from the front (list.pop(0) is O(n)).\n",
    "        queue = deque([self.root, None])\n",
    "        while queue:\n",
    "            item = queue.popleft()\n",
    "            if item is None:\n",
    "                print('\\n')\n",
    "            else:\n",
    "                print(item.uid, end=' ')\n",
    "                queue.extend(item.children.values())\n",
    "                queue.append(None)\n",
    "\n",
    "    def insertnode(self, curnode, parent, child):\n",
    "        \"\"\"Add ``child`` below the first node in ``curnode``'s subtree with uid ``parent.uid``.\n",
    "\n",
    "        Depth-first search; at each level a direct child with the wanted uid is\n",
    "        used before descending into other subtrees.\n",
    "\n",
    "        Returns:\n",
    "            1 if ``child`` was added to ``curnode`` itself, 2 if it was added to a\n",
    "            descendant, None if no node in the subtree has uid ``parent.uid``.\n",
    "        \"\"\"\n",
    "        if curnode.uid == parent.uid:\n",
    "            curnode.add_child(child)\n",
    "            return 1\n",
    "\n",
    "        if parent.uid in curnode.children:\n",
    "            self.insertnode(curnode.children[parent.uid], parent, child)\n",
    "            return 2\n",
    "\n",
    "        for node in curnode.children.values():\n",
    "            # Propagate success and stop at the first insertion, so a parent uid\n",
    "            # that occurs in several deeper subtrees receives the child only once.\n",
    "            if self.insertnode(node, parent, child) is not None:\n",
    "                return 2\n",
    "        return None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def loadPklFileNum(datapath, incSize, fileNum):\n",
    "    \"\"\"Load one pickled tree shard named ``<datapath><incSize>inc_<fileNum>.pickle``.\n",
    "\n",
    "    Note: unpickling can run arbitrary code, so only load trusted files.\n",
    "    \"\"\"\n",
    "    shard_path = datapath + f'{incSize}inc_{fileNum}.pickle'\n",
    "    with open(shard_path, 'rb') as shard:\n",
    "        return pkl.load(shard)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def loadTreeFilesOfIncrement(datapath, incSize):\n",
    "    \"\"\"Load every pickled tree shard for one increment size and merge them.\n",
    "\n",
    "    Shards are files in ``datapath`` named ``<incSize>inc_<n>.pickle`` (the\n",
    "    naming ``loadPklFileNum`` reads). Each shard is unpickled and merged with\n",
    "    ``dict.update`` in sorted filename order, so later shards win on key clashes.\n",
    "\n",
    "    Args:\n",
    "        datapath: directory containing the shard files.\n",
    "        incSize: increment size encoded in the shard filenames.\n",
    "\n",
    "    Returns:\n",
    "        dict with the merged contents of all matching shards.\n",
    "\n",
    "    Note: unpickling can run arbitrary code, so only load trusted files.\n",
    "    \"\"\"\n",
    "    prefix = str(incSize) + 'inc_'\n",
    "    # List ``datapath`` itself (this used to read the global ``t15Datapath``).\n",
    "    # A prefix match keeps e.g. incSize=20 from also picking up ``120inc_*``\n",
    "    # files; sorting makes the merge order deterministic.\n",
    "    files = sorted(name for name in os.listdir(datapath) if name.startswith(prefix))\n",
    "\n",
    "    twittertrees = {}\n",
    "    for file in tqdm(files):\n",
    "        with open(os.path.join(datapath, file), 'rb') as handle:\n",
    "            partialTrees = pkl.load(handle)\n",
    "        twittertrees.update(partialTrees)\n",
    "    return twittertrees"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = 'twitter15'\n",
    "\n",
    "# Run the dataset's text preprocessing and point at its pickled trees. The path\n",
    "# variable keeps the name t15Datapath for both datasets because later cells read it.\n",
    "if dataset == 'twitter15':\n",
    "    %run ../twitter15/twitter15_text_processing.ipynb\n",
    "    t15Datapath = '../twitter15/pickledTrees/'\n",
    "elif dataset == 'twitter16':\n",
    "    %run ../twitter16/twitter16_text_processing.ipynb\n",
    "    t15Datapath = '../twitter16/pickledTrees/'\n",
    "else:\n",
    "    # Fail here instead of with a NameError on t15Datapath further down.\n",
    "    raise ValueError(f'unknown dataset: {dataset!r}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "aac162b301944c0b88bfd1f079fd8c5d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=16), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Merge all tree shards whose filenames encode increment size 20. The name says\n",
    "# twitter15, but t15Datapath points at whichever dataset was selected above.\n",
    "twitter15_trees = loadTreeFilesOfIncrement(t15Datapath,20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 34/34 [03:47<00:00,  8.14s/it]\n",
      "100%|██████████| 430343/430343 [01:27<00:00, 4921.27it/s]\n"
     ]
    }
   ],
   "source": [
    "# Run the dataset's user-data parser, load its precomputed k-fold split and\n",
    "# choose where per-fold checkpoints and results are written.\n",
    "if dataset == 'twitter15':\n",
    "    %run ../twitter15/userdata_parser.ipynb\n",
    "    with open('./kfolds_twitter15.pickle','rb') as f:\n",
    "        kfold_df = pkl.load(f)\n",
    "    datadir = './twitter15_kfold_results/'\n",
    "elif dataset == 'twitter16':\n",
    "    %run ../twitter16/userdata_parser.ipynb\n",
    "    with open('./kfolds_twitter16.pickle','rb') as f:\n",
    "        kfold_df = pkl.load(f)\n",
    "    datadir = './twitter16_kfold_results/'\n",
    "else:\n",
    "    raise ValueError(f'unknown dataset: {dataset!r}')\n",
    "\n",
    "# trainModelWithKfold saves checkpoints and appends results under datadir,\n",
    "# which fails if the directory does not exist yet.\n",
    "os.makedirs(datadir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 430343/430343 [00:01<00:00, 296193.27it/s]\n"
     ]
    }
   ],
   "source": [
    "# Cast every user vector to float (values assumed to be torch tensors - confirm in userdata_parser).\n",
    "for key in tqdm(userVects):\n",
    "    userVects[key] = userVects[key].float()\n",
    "\n",
    "# Users without an entry fall back to this fixed 8-value vector; as with any\n",
    "# defaultdict, looking up a missing key also stores the fallback under it.\n",
    "# NOTE(review): where these constants come from is not documented - confirm.\n",
    "userVects = defaultdict(lambda:torch.tensor([1.1100e+02, 1.5000e+01, 0.0000e+00, 7.9700e+02, 4.7300e+02, 0.0000e+00,\n",
    "        8.3326e+04, 1.0000e+00]),userVects)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "I0306 14:46:22.169637 140318102980416 file_utils.py:41] PyTorch version 1.2.0 available.\n"
     ]
    }
   ],
   "source": [
    "# Execute textEncoders.ipynb so its definitions become available in this notebook.\n",
    "%run ./textEncoders.ipynb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Execute temporal_tree_model.ipynb; presumably it defines temporalDecayTreeEncoder\n",
    "# and lstmTreeEncoder, which the training cells below instantiate.\n",
    "%run ./temporal_tree_model.ipynb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training device. CPU is the default (the old cell set 'cuda:1' and then\n",
    "# immediately overwrote it with 'cpu'); set USE_CUDA = True to use the GPU.\n",
    "USE_CUDA = False\n",
    "device = 'cuda:1' if USE_CUDA and torch.cuda.is_available() else 'cpu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Fixed mapping from label string to class index (four classes).\n",
    "labelMap = {label: index for index, label in enumerate(('true', 'false', 'unverified', 'non-rumor'))}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _evenly_spaced_indices(m, n):\n",
    "    \"\"\"Return ``m`` indices into ``range(n)``, one near the centre of each of ``m`` equal bins.\"\"\"\n",
    "    return [i * n // m + n // (2 * m) for i in range(m)]\n",
    "\n",
    "\n",
    "def trainModelWithKfold(foldnum, datadir, model, modelname, numtrees, epochs=3, lr=0.01):\n",
    "    \"\"\"Train ``model`` on one k-fold split, checkpointing the best epoch.\n",
    "\n",
    "    Every example is ``twitter15_trees[tid]`` (indexed like a sequence of trees)\n",
    "    reduced to ``numtrees`` trees taken at evenly spaced positions; its target is\n",
    "    ``labelMap`` of the first selected tree's root label.\n",
    "\n",
    "    After each epoch the model is scored on the fold's validation ids. When\n",
    "    validation accuracy improves, weights go to\n",
    "    ``<datadir><modelname>_<foldnum>.pth``. At the end, the best epoch's\n",
    "    classification report (plus 'loss', 'Acc' and 'fold') is appended as one\n",
    "    JSON line to ``<datadir><modelname>.json``.\n",
    "\n",
    "    Reads the notebook globals ``kfold_df``, ``twitter15_trees``, ``labelMap``\n",
    "    and ``device``.\n",
    "\n",
    "    Args:\n",
    "        foldnum: key into ``kfold_df`` yielding ``(train_ids, valid_ids)``.\n",
    "        datadir: output path prefix; must end with a path separator.\n",
    "        model: torch module whose output reshapes to ``(-1, 4)`` class scores.\n",
    "        modelname: prefix for the checkpoint and results files.\n",
    "        numtrees: number of trees selected per example.\n",
    "        epochs: passes over the training ids (default 3, as before).\n",
    "        lr: Adagrad learning rate (default 0.01, as before).\n",
    "    \"\"\"\n",
    "    train_ids, valid_ids = kfold_df[foldnum]\n",
    "    train_sets = [twitter15_trees[tid] for tid in train_ids]\n",
    "    valid_sets = [twitter15_trees[tid] for tid in valid_ids]\n",
    "\n",
    "    criterion = torch.nn.CrossEntropyLoss()\n",
    "    optimizer = torch.optim.Adagrad(model.parameters(), lr=lr)\n",
    "\n",
    "    # Start below any reachable accuracy so the first epoch is always kept;\n",
    "    # starting at 0 meant a fold stuck at 0% accuracy saved no checkpoint and\n",
    "    # wrote a bare 0 instead of a report.\n",
    "    maxAcc = -1.0\n",
    "    bestcr = None\n",
    "\n",
    "    for _ in range(epochs):\n",
    "        model.train()  # training mode (matters for dropout/batch-norm layers)\n",
    "        for treeSet in tqdm(train_sets):\n",
    "            trees = [treeSet[idx] for idx in _evenly_spaced_indices(numtrees, len(treeSet))]\n",
    "            optimizer.zero_grad()\n",
    "            pred = model(trees)\n",
    "            # Same device as the validation labels below.\n",
    "            label = torch.tensor(labelMap[trees[0].root.label]).to(device)\n",
    "            loss = criterion(pred.reshape(-1, 4), label.reshape(-1))\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "        model.eval()  # inference mode for validation\n",
    "        preds = []\n",
    "        allLabels = []\n",
    "        allPreds = []\n",
    "        with torch.no_grad():\n",
    "            for valSet in tqdm(valid_sets):\n",
    "                trees = [valSet[idx] for idx in _evenly_spaced_indices(numtrees, len(valSet))]\n",
    "                predicted = model(trees)\n",
    "                preds.append(predicted)\n",
    "                # argmax of the scores equals argmax of their softmax.\n",
    "                allPreds.append(int(torch.argmax(predicted[0], dim=0)))\n",
    "                allLabels.append(labelMap[trees[0].root.label])\n",
    "\n",
    "        predTensor = torch.stack(preds)\n",
    "        labelTensor = torch.tensor(allLabels).to(device)\n",
    "        val_loss = criterion(predTensor.reshape(-1, 4), labelTensor.reshape(-1))\n",
    "\n",
    "        cr = classification_report(allLabels, allPreds, output_dict=True)\n",
    "        cr['loss'] = val_loss.item()\n",
    "        cr['Acc'] = accuracy_score(allLabels, allPreds)\n",
    "        cr['fold'] = foldnum\n",
    "        print('loss: ', cr['loss'])\n",
    "        print(cr['Acc'])\n",
    "\n",
    "        if cr['Acc'] > maxAcc:\n",
    "            maxAcc = cr['Acc']\n",
    "            bestcr = cr\n",
    "            torch.save({'state_dict': model.state_dict()}, datadir + modelname + '_' + str(foldnum) + '.pth')\n",
    "\n",
    "    if bestcr is not None:  # only None when epochs == 0\n",
    "        with open(datadir + modelname + '.json', 'a') as fp:\n",
    "            json.dump(bestcr, fp)\n",
    "            fp.write('\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/1 [00:00<?, ?it/s]/home/gmanish/anaconda3/envs/py36/lib/python3.6/site-packages/ipykernel_launcher.py:17: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.\n",
      "100%|██████████| 1/1 [00:00<00:00,  1.39it/s]\n",
      "100%|██████████| 1/1 [00:00<00:00,  7.87it/s]\n",
      "/home/gmanish/anaconda3/envs/py36/lib/python3.6/site-packages/sklearn/metrics/classification.py:1143: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples.\n",
      "  'precision', 'predicted', average, warn_for)\n",
      "/home/gmanish/anaconda3/envs/py36/lib/python3.6/site-packages/sklearn/metrics/classification.py:1145: UndefinedMetricWarning: Recall and F-score are ill-defined and being set to 0.0 in labels with no true samples.\n",
      "  'recall', 'true', average, warn_for)\n",
      "  0%|          | 0/1 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2] [1]\n",
      "loss:  2.253389596939087\n",
      "0.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:00<00:00,  3.18it/s]\n",
      "100%|██████████| 1/1 [00:00<00:00,  3.80it/s]\n",
      "  0%|          | 0/1 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1] [1]\n",
      "loss:  0.7927819490432739\n",
      "1.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:00<00:00,  3.16it/s]\n",
      "100%|██████████| 1/1 [00:00<00:00,  1.02it/s]\n",
      "  0%|          | 0/1 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2] [2]\n",
      "loss:  0.4069879949092865\n",
      "1.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:00<00:00,  2.94it/s]\n",
      "100%|██████████| 1/1 [00:00<00:00,  3.38it/s]\n",
      "  0%|          | 0/1 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0] [2]\n",
      "loss:  5.676121234893799\n",
      "0.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:01<00:00,  1.19s/it]\n",
      "100%|██████████| 1/1 [00:00<00:00,  3.54it/s]\n",
      "  0%|          | 0/1 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2] [2]\n",
      "loss:  0.129058837890625\n",
      "1.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:00<00:00,  1.48it/s]\n",
      "100%|██████████| 1/1 [00:00<00:00,  1.98it/s]\n",
      "  0%|          | 0/1 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1] [2]\n",
      "loss:  1.6049168109893799\n",
      "0.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:00<00:00,  1.43it/s]\n",
      "100%|██████████| 1/1 [00:01<00:00,  1.86s/it]\n",
      "  0%|          | 0/1 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2] [2]\n",
      "loss:  0.002583739347755909\n",
      "1.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:00<00:00,  1.44it/s]\n",
      "100%|██████████| 1/1 [00:00<00:00,  2.48it/s]\n",
      "  0%|          | 0/1 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0] [2]\n",
      "loss:  12.154277801513672\n",
      "0.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:01<00:00,  1.81s/it]\n",
      "100%|██████████| 1/1 [00:00<00:00,  2.23it/s]\n",
      "  0%|          | 0/1 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2] [2]\n",
      "loss:  0.00046456989366561174\n",
      "1.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:01<00:00,  1.12s/it]\n",
      "100%|██████████| 1/1 [00:00<00:00,  1.26it/s]\n",
      "  0%|          | 0/1 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1] [2]\n",
      "loss:  6.986588954925537\n",
      "0.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:01<00:00,  1.06s/it]\n",
      "100%|██████████| 1/1 [00:02<00:00,  2.86s/it]\n",
      "  0%|          | 0/1 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2] [2]\n",
      "loss:  4.60137271147687e-05\n",
      "1.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:01<00:00,  1.08s/it]\n",
      "100%|██████████| 1/1 [00:00<00:00,  1.56it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0] [2]\n",
      "loss:  17.359285354614258\n",
      "0.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "criterion = torch.nn.CrossEntropyLoss()\n",
    "numtreearr = [15]\n",
    "\n",
    "for numtrees in numtreearr:\n",
    "    modelname = f'decay_tempTreeEnc_{numtrees}'\n",
    "    # NOTE(review): runs folds 1-3 only, while the LSTM cell below runs folds 0-3;\n",
    "    # confirm fold 0 is meant to be skipped here.\n",
    "    for i in range(1, 4):\n",
    "        # Build a fresh model for every fold so no weights carry across folds.\n",
    "        model = temporalDecayTreeEncoder(\n",
    "            torch.cuda.is_available(), 8, 100, userVects, twitter15_labels,\n",
    "            labelMap, criterion, device,\n",
    "        )\n",
    "        trainModelWithKfold(i, datadir, model, modelname, numtrees)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/1 [00:00<?, ?it/s]/home/gmanish/anaconda3/envs/py36/lib/python3.6/site-packages/ipykernel_launcher.py:17: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.\n",
      "100%|██████████| 1/1 [00:00<00:00,  1.70it/s]\n",
      "100%|██████████| 1/1 [00:00<00:00,  6.90it/s]\n",
      "  0%|          | 0/1 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2] [1]\n",
      "loss:  1.8845772743225098\n",
      "0.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:00<00:00,  3.22it/s]\n",
      "100%|██████████| 1/1 [00:00<00:00,  3.76it/s]\n",
      "  0%|          | 0/1 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1] [1]\n",
      "loss:  0.5592848062515259\n",
      "1.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:00<00:00,  3.02it/s]\n",
      "100%|██████████| 1/1 [00:00<00:00,  1.15it/s]\n",
      "  0%|          | 0/1 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2] [2]\n",
      "loss:  0.7918633222579956\n",
      "1.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:00<00:00,  2.94it/s]\n",
      "100%|██████████| 1/1 [00:00<00:00,  5.00it/s]\n",
      "  0%|          | 0/1 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0] [2]\n",
      "loss:  4.407657146453857\n",
      "0.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:01<00:00,  1.07s/it]\n",
      "100%|██████████| 1/1 [00:00<00:00,  3.38it/s]\n",
      "  0%|          | 0/1 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2] [2]\n",
      "loss:  0.1386573314666748\n",
      "1.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:00<00:00,  1.51it/s]\n",
      "100%|██████████| 1/1 [00:00<00:00,  1.95it/s]\n",
      "  0%|          | 0/1 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1] [1]\n",
      "loss:  0.5178874731063843\n",
      "1.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:00<00:00,  1.45it/s]\n",
      "100%|██████████| 1/1 [00:01<00:00,  1.75s/it]\n",
      "  0%|          | 0/1 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2] [1]\n",
      "loss:  1.549131989479065\n",
      "0.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:00<00:00,  1.48it/s]\n",
      "100%|██████████| 1/1 [00:00<00:00,  2.60it/s]\n",
      "  0%|          | 0/1 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0] [1]\n",
      "loss:  2.8819949626922607\n",
      "0.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:01<00:00,  1.64s/it]\n",
      "100%|██████████| 1/1 [00:00<00:00,  2.29it/s]\n",
      "  0%|          | 0/1 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2] [2]\n",
      "loss:  4.255681051290594e-05\n",
      "1.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:01<00:00,  1.02s/it]\n",
      "100%|██████████| 1/1 [00:00<00:00,  1.30it/s]\n",
      "  0%|          | 0/1 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1] [1]\n",
      "loss:  0.07381763309240341\n",
      "1.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:01<00:00,  1.04s/it]\n",
      "100%|██████████| 1/1 [00:02<00:00,  2.67s/it]\n",
      "  0%|          | 0/1 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2] [1]\n",
      "loss:  2.3132643699645996\n",
      "0.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:01<00:00,  1.05s/it]\n",
      "100%|██████████| 1/1 [00:00<00:00,  1.68it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0] [2]\n",
      "loss:  8.622230529785156\n",
      "0.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "criterion = torch.nn.CrossEntropyLoss()\n",
    "numtreearr = [10,15]\n",
    "\n",
    "for numtrees in numtreearr:\n",
    "    # Derive the name from the current numtrees; it used to be computed once\n",
    "    # before the loop from a stale numtrees, so every run shared one name.\n",
    "    modelname = 'std_tempTreeEnc_'+str(numtrees)\n",
    "    for i in range(0,4):\n",
    "        # Fresh model per fold: reusing one model let weights trained on other\n",
    "        # folds (in a k-fold split those include this fold's validation ids) leak in.\n",
    "        model = lstmTreeEncoder(torch.cuda.is_available(),8,100,userVects,twitter15_labels,labelMap,criterion,device)\n",
    "        trainModelWithKfold(i,datadir,model,modelname,numtrees)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Reference code for a single random train/test split (this is a markdown cell, so it is not executed):\n",
    "\n",
    "```python\n",
    "epochs = 10\n",
    "X = []\n",
    "y = []\n",
    "X_text = []\n",
    "\n",
    "\n",
    "\n",
    "for tid in twitter15_trees:\n",
    "        if tid in twitter15_trees and tid in twitter15_labels:\n",
    "            X.append(tuple((twitter15_trees[tid],twitter15_text[tid])))\n",
    "            y.append(labelMap[twitter15_labels[tid]])\n",
    "            X_text.append(twitter15_text[tid])\n",
    "            \n",
    "x_train,x_test,y_train,y_test = train_test_split(X,y,random_state=2018)\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Optional: uncomment to load the saved state_dict from the checkpoint below into `model`.\n",
    "# model.load_state_dict(torch.load('./pretrainedModels-Twit15/'+'std_tempTreeEnc_pretrained_inc20'+'.pth')['state_dict'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
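    "Reference code for pretraining the text encoders (this is a markdown cell, so it is not executed).\n",
    "\n",
    "**NOTE:** `textparams[bidir]` and `textparams[rnnType]` index with variables; `bidir` is never defined, so the loop fails with a `NameError` as written. String keys (e.g. `'bidir'`, `'rnnType'`) were probably intended - confirm against `TextEncoder`.\n",
    "\n",
    "```python\n",
    "rnnTypes = ['gru','lstm']\n",
    "bidirTypes = [True,False]\n",
    "\n",
    "for rnnType in rnnTypes:\n",
    "    for bidirType in bidirTypes:\n",
    "        textparams[bidir] = bidirType\n",
    "        textparams[rnnType] = rnnType\n",
    "        \n",
    "        textEncoderModel = TextEncoder('rnn',textparams,X_text,y,device)\n",
    "        textEncoderModel.trainModel()\n",
    "        print(len(textEncoderModel.word2idx))\n",
    "        \n",
    "        modelname = rnnType\n",
    "        if bidirType:\n",
    "            modelname = 'bidir'+modelname\n",
    "        \n",
    "        torch.save({'state_dict': textEncoderModel.optimalParams}, './pretrainedModels-Twit15/'+modelname+'.pth')\n",
    "```"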
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
